Qual è il modo più semplice per trasformare il tensore di forma (batch_size, height, width) riempito con n valori in tensor of shape (batch_size, n, height, width)? Ho creato la soluzione di seguito, ma sembra che ci siano modi più semplici e veloci per farlo def batch_tensor_to_onehot (tnsr, classes): tnsr = tnsr.unsqueeze (1) res = [] per cls in range (classi): res.append ((tnsr == cls) .long ()) cat. torcia di ritorno (res, dim = 1)
2021-02-20 08:18:39
Puoi usare torch.nn.functional.one_hot. Per il tuo caso: a = torch.nn.functional.one_hot (tnsr, num_classes = classes) out = a.permute (0, 3, 1, 2) | Potresti anche usare Tensor.scatter_ che evita .permute ma è probabilmente più difficile da capire rispetto al metodo semplice proposto da @Alpha. def batch_tensor_to_onehot (tnsr, classes): risultato = torch.zeros (tnsr.shape [0], classes, * tnsr.shape [1:], dtype = torch.long, device = tnsr.device) risultato.scatter_ (1, tnsr.unsqueeze (1), 1) risultato di ritorno Risultati del benchmarking Ero curioso e ho deciso di confrontare i tre approcci. Ho scoperto che non sembra esserci una differenza relativa significativa tra i metodi proposti per quanto riguarda la dimensione del lotto, la larghezza o l'altezza. Principalmente il numero di classi era il fattore distintivo. Ovviamente come con qualsiasi benchmark il chilometraggio può variare. I benchmark sono stati raccolti utilizzando indici casuali e utilizzando batch-size, altezza, larghezza = 100. Ciascun esperimento è stato ripetuto 20 volte riportando la media. L'esperimento num_classes = 100 viene eseguito una volta prima della creazione del profilo per il riscaldamento. I risultati della CPU mostrano che il metodo originale era probabilmente il migliore per num_classes inferiori a circa 30, mentre per GPU l'approccio scatter_ sembra essere il più veloce. Test eseguiti su Ubuntu 18.04, NVIDIA 2060 Super, i7-9700K Il codice utilizzato per il benchmarking è fornito di seguito: importare torcia da tqdm importa tqdm tempo di importazione importa matplotlib.pyplot come plt def batch_tensor_to_onehot_slavka (tnsr, classes): tnsr = tnsr.unsqueeze (1) res = [] per cls in range (classi): res.append ((tnsr == cls) .long ()) cat. torcia di ritorno (res, dim = 1) def batch_tensor_to_onehot_alpha (tnsr, classes): risultato = torch.nn.functional.one_hot (tnsr, num_classes = classes) restituisce il risultato. permuto (0, 3, 1, 2) def batch_tensor_to_onehot_jodag (tnsr, classes): risultato = torch.zeros (tnsr.shape [0], classes, * tnsr.shape [1:], dtype = torch.long, device = tnsr.device) risultato.scatter_ (1, tnsr.unsqueeze (1), 1) risultato di ritorno def main (): num_classes = [2, 10, 25, 50, 100] altezza = 100 larghezza = 100 bs = [100] * 20 for d in ['cpu', 'cuda']: times_slavka = [] times_alpha = [] times_jodag = [] warmup = vero per c in tqdm ([num_classes [-1]] + num_classes, ncols = 0): tslavka = 0 talfa = 0 tjodag = 0 per b in bs: tnsr = torch.randint (c, (b, height, width)). to (device = d) t0 = time.time () y = batch_tensor_to_onehot_slavka (tnsr, c) torch.cuda.synchronize () tslavka + = time.time () - t0 in caso contrario: times_slavka.append (tslavka / len (bs)) per b in bs: tnsr = torch.randint (c, (b, height, width)). to (device = d) t0 = time.time () y = batch_tensor_to_onehot_alpha (tnsr, c) torch.cuda.synchronize () talpha + = time.time () - t0 se non riscaldamento: times_alpha.append (talpha / len (bs)) per b in bs: tnsr = torch.randint (c, (b, height, width)). to (device = d) t0 = time.time () y = batch_tensor_to_onehot_jodag (tnsr, c) torch.cuda.synchronize () tjodag + = time.time () - t0 se non riscaldamento: times_jodag.append (tjodag / len (bs)) warmup = Falso fig = plt. figura () ax = fig.subplots () ax.plot (num_classes, times_slavka, label = 'Slavka-cat') ax.plot (num_classes, times_alpha, label = 'Alpha-one_hot') ax.plot (num_classes, times_jodag, label = 'jodag-scatter_') ax.set_xlabel ('num_classes') ax.set_ylabel ('time (s)') ax.set_title (f '{d} benchmark') ax.legend () plt.savefig (f '{d} .png') plt. mostra () se __name__ == "__main__": principale() | La tua risposta StackExchange.ifUsing ("editor", function () { StackExchange.using ("externalEditor", function () { StackExchange.using ("snippets", function () { StackExchange.snippets.init (); }); }); }, "frammenti di codice"); StackExchange.ready (function () { var channelOptions = { tag: "" .split (""), id: "1" }; initTagRenderer ("". split (""), "" .split (""), channelOptions); StackExchange.using ("externalEditor", function () { // Devo attivare l'editor dopo gli snippet, se gli snippet sono abilitati if (StackExchange.settings.snippets.snippetsEnabled) { StackExchange.using ("snippets", function () { createEditor (); }); } altro { createEditor (); } }); function createEditor () { StackExchange.prepareEditor ({ useStacksEditor: false, heartbeatType: 'answer', autoActivateHeartbeat: false, convertImagesToLinks: true, noModals: true, showLowRepImageUploadWarning: true, reputationToPostImages: 10, bindNavPrevention: true, suffisso: "", imageUploader: { brandingHtml: "Powered by \ u003ca href = \" https: //imgur.com/ \ "\ u003e \ u003csvg class = \" svg-icon \ "width = \" 50 \ "height = \" 18 \ "viewBox = \ "0 0 50 18 \" fill = \ "none \" xmlns = \ "http: //www.w3.org/2000/svg \" \ u003e \ u003cpath d = \ "M46.1709 9.17788C46.1709 8.26454 46.2665 7.94324 47.1084 7.58816C47.4091 7.46349 47.7169 7.36433 48.0099 7.26993C48.9099 6.97997 49.672 6.73443 49.672 5.93063C49.672 5.22043 48.9832 4.61182 48.1414 4.61182C47.4335 4.62993C48.9099 6.97997 49.672 6.73443 49.672 5.93063C49.672 5.22043 48.9832 4.61182 48.1414 4.61182C47.4335 4.62481 46.72543 4.9162889 464562 46.72543 4.9162889 4.645.2543 4.6531 4.69562 4.65.695.65.69562 C4562 4.6531.495.65.69562 C4562.65.695.6531.495.69562 C4562. 43.1481 6.59048V11.9512C43.1481 13.2535 43.6264 13.8962 44.6595 13.8962C45.6924 13.8962 46.1709 13.253546.1709 11.9512V9.17788Z \ "/ \ u003e \ u003cpath d = \" M32.492 10.1419C32.492 12.6954 34.1182 14.0484 37.0451 14.0484C39.9723 14.0484 41.5985 12.6954 41.5985 10.1419V6.59049C41.598532 41.59088 4.632 4.632 4.632 4.632 4.639 4.632 4.632 4.632 4.639 4.632 38.5948 5.28821 38.5948 6.59049V9.60062C38.5948 10.8521 38.2696 11.5455 37.0451 11.5455C35.8209 11.5455 35.4954 10.8521 35.4954 9.60062V6.59049C35.4954 5.28821 35.0173 4.66232 34.0034 4.66232C32V32.900.492.192.492.492.1329.49.492.492C32 fill-rule = \ "evenodd \" clip-rule = \ "evenodd \" d = \ "M25.6622 17.6335C27.8049 17.6335 29.3739 16.9402 30.2537 15.6379C30.8468 14.7755 30.9615 13.5579 30.9615 11.9512V6.59049C30.9615 5.28862 30.4 29.4502 4.66231C28.9913 4.66231 28.4555 4.94978 28.1109 5.50789C27.499 4.86533 26.7335 4.56087 25.7005 4.56087C23.1369 4.56087 21.0134 6.57349 21.0134 9.27932C21.0134 11.9852 23.003 13.713 25.3756 13.916.1 1226.722 1226.122 1226.122 1226.722 1226.722 1226.56 13.922 1226.722 1226.122 1226.722 1226.56 13.922 1226.19 C28. 1256 12.8854 28,1301 12,9342 28,1301 12.983C28.1301 14,4373 27,2502 15,2321 25,777 15.2321C24.8349 15,2321 24,1352 14,9821 23,5661 14.7787C23.176 14,6393 22,8472 14,5218 22,5437 14.5218C21.7977 14,5218 21,2429 15,0123 21,2429 15.6887C21.2429 16,7375 22,9072 17,6335 25,6622 17.6335ZM24.1317 9,27,932 mila C24.1317 7.94324 24.9928 7.09766 26.1024 7.09766C27.2119 7.09766 28.0918 7.94324 28.0918 9.27932C28.0918 10.6321 27.2311 11.5116 26.1024 11.5116C24.9737 11.5116 24.1317 10.6491 24.1317 9.27932Z \ "/ \ u003.percorso \ u45. 8045 13.2535 17.2637 13.8962 18.2965 13.8962C19.3298 13.8962 19.8079 13.2535 19.8079 11.9512V8.12928C19.8079 5.82936 18.4879 4.62866 16.4027 4.62866C15.1594 4.62866 14.279 4.98375 13.3609 5.88013C12.653 5.05154. 58314 4.9328 7.10506 4.66232 6.51203 4.66232C5.47873 4.66232 5.00066 5.28821 5.00066 6.59049V11.9512C5.00066 13.2535 5.47873 13.8962 6.51203 13.8962C7.54479 13.8962 8.0232 13 .2535 8.0232 11.9512V8.90741C8.0232 7.58817 8.44431 6.91179 9.53458 6.91179C10.5104 6.91179 10.893 7.58817 10.893 8.94108V11.9512C10.893 13.2535 11.3711 13.8962 12.4044 13.8962C13.4375 13.8715.95 14.917.95 13.915.90 C13.4375 13.8715.95 13.915.379 C16.4027 6.91179 16.8045 7.58817 16.8045 8.94108V11.9512Z \ "/ \ u003e \ u003cpath d = \" M3.31675 6.59049C3.31675 5.28821 2.83866 4.66232 1.82471 4.66232C0.791758 4.66232 0.313354 5.28821 0.313354 6.590.25V1135 1.82471 13.8962C2.85798 13.8962 3.31675 13.2535 3.31675 11.9512V6.59049Z \ "/ \ u003e \ u003cpath d = \" M1.87209 0.400291C0.843612 0.400291 0 1.1159 0 1.98861C0 2.87869 0.822846 3.576700 1.87869 0.822846 3.57676 1.87869 0.822846 3.57676 1.87869 0.822846 3.57676 1.8720.772 C.772.772. C3.7234 1.1159 2.90056 0.400291 1.87209 0.400291Z \ "fill = \" # 1BB76E \ "/ \ u003e \ u003c / svg \ u003e \ u003c / a \ u003e", contentPolicyHtml: "Contributi degli utenti con licenza \ u003ca href = \" https: //stackoverflow.com/help/licensing \ "\ u003ecc by-sa \ u003c / a \ u003e \ u003ca href = \" https://stackoverflow.com / legal / content-policy \ "\ u003e (content policy) \ u003c / a \ u003e", allowUrls: true }, onDemand: true, discardSelector: ".discard-answer" , immediatamenteShowMarkdownHelp: true, enableTables: true, enableSnippets: true }); } }); Grazie per aver contribuito con una risposta a Stack Overflow! Assicurati di rispondere alla domanda. Fornisci dettagli e condividi la tua ricerca! Ma evita ... Chiedere aiuto, chiarimenti o rispondere ad altre risposte. Fare dichiarazioni basate su opinioni; sostenerli con riferimenti o esperienza personale. Per saperne di più, consulta i nostri suggerimenti su come scrivere ottime risposte. Bozza salvata Bozza scartata Registrati o fai il login StackExchange.ready (function () { StackExchange.helpers.onClickDraftSave ('# login-link'); }); Registrati utilizzando Google Iscriviti utilizzando Facebook Iscriviti utilizzando e-mail e password Invia Pubblica come ospite Nome E-mail Obbligatorio, ma mai mostrato StackExchange.ready ( funzione () { StackExchange.openid.initPostLogin (". New-post-login", "https% 3a% 2f% 2fstackoverflow.com% 2fquestions% 2f62245173% 2fpytorch-transform-tensor-to-one-hot% 23new-answer", "question_page" ); } ); Pubblica come ospite Nome E-mail Obbligatorio, ma mai mostrato Pubblica la tua risposta Scartare Facendo clic su "Pubblica la tua risposta", accetti i nostri termini di servizio, politica sulla privacy e politica sui cookie Non è la risposta che stai cercando? Sfoglia altre domande taggate python pytorch tensor one-hot-encoding o fai la tua domanda.